#!/usr/bin/env python
# (C) 2000 Huaiyu Zhu <hzhu@users.sourceforge.net>.  Licence: GPL
# $I$
"""
Auxiliary stuff for MatPy:
	formatting
	axes removal
"""

#------------------------------------------------------------------
# Formatting functions
"??? Is there a better way to define format for complex numbers?"
import string
def shortI(x):
	"""Format an integer compactly: sign column, left-justified to width 3."""
	spec = "% -3d"
	return spec % x
def shortF(x):
	"""Format a float compactly: 3 significant digits, sign column, left-justified to width 6."""
	spec = "% -6.3g"
	return spec % x
def shortC(x):
	"""Format a complex number compactly as 'a+bj' (3 significant digits per part), right-justified to width 8."""
	re_part, im_part = x.real, x.imag
	text = "% .3g% +.3gj" % (re_part, im_part)
	return "%8s" % text

def longI(x):
	"""Format an integer left-justified to width 11."""
	spec = "%-11d"
	return spec % x
def longF(x):
	"""Format a float with 8 significant digits, left-justified to width 11."""
	spec = "%-11.8g"
	return spec % x
def longC(x):
	"""Format a complex number as 'a+bj' (8 significant digits per part), right-justified to width 19."""
	re_part, im_part = x.real, x.imag
	text = "% .8g% +.8gj" % (re_part, im_part)
	return "%19s" % text

#------------------------------------------------------------------
# Active element formatters, looked up as module globals at call time by the
# repr*/str* helpers; rebind them (e.g. to longI/longF/longC) to change output.
formatI, formatF, formatC = shortI, shortF, shortC

def reprI(seq, n=0):
	"""Return a bracketed, comma-separated representation of an integer multiarray.

	seq -- array-like object with a ``shape`` attribute.
	n   -- current nesting depth; continuation rows are indented by n+1 spaces.

	Rank <= 1 arrays are rendered on one line with formatI; higher ranks
	recurse over the first axis, placing one sub-array per line.
	"""
	if len(seq.shape) <= 1:
		return "[ " + ', '.join([formatI(x) for x in seq]) + " ]"
	rows = [reprI(x, n + 1) for x in seq]
	# Indent continuation rows so they line up under the opening bracket.
	return "[" + (',\n' + ' ' * (n + 1)).join(rows) + "]"

def strI(seq, n=0):
	"""Return a human-readable string for an integer multiarray.

	seq -- array-like object with a ``shape`` attribute.
	n   -- current nesting depth; continuation rows are indented by n+1 spaces.

	Rank <= 1: elements formatted with formatI, space-separated, no brackets.
	Higher ranks: one sub-array per line, wrapped in "[" ... " ]".
	"""
	if len(seq.shape) <= 1:
		return ' '.join([formatI(x) for x in seq])
	rows = [strI(x, n + 1) for x in seq]
	return "[" + ('\n' + ' ' * (n + 1)).join(rows) + " ]"

def reprF(seq, n=0):
	"""Return a bracketed, comma-separated representation of a float multiarray.

	seq -- array-like object with a ``shape`` attribute.
	n   -- current nesting depth; continuation rows are indented by n+1 spaces.

	Rank <= 1 arrays are rendered on one line with formatF; higher ranks
	recurse over the first axis, placing one sub-array per line.
	"""
	if len(seq.shape) <= 1:
		return "[ " + ', '.join([formatF(x) for x in seq]) + " ]"
	rows = [reprF(x, n + 1) for x in seq]
	# Indent continuation rows so they line up under the opening bracket.
	return "[" + (',\n' + ' ' * (n + 1)).join(rows) + "]"

def strF(seq, n=0):
	"""Return a human-readable string for a float multiarray.

	seq -- array-like object with a ``shape`` attribute.
	n   -- current nesting depth; continuation rows are indented by n+1 spaces.

	Rank <= 1: elements formatted with formatF, space-separated, no brackets.
	Higher ranks: one sub-array per line, wrapped in "[" ... " ]".
	"""
	if len(seq.shape) <= 1:
		return ' '.join([formatF(x) for x in seq])
	rows = [strF(x, n + 1) for x in seq]
	return "[" + ('\n' + ' ' * (n + 1)).join(rows) + " ]"

def reprC(seq, n=0):
	"""Return a bracketed, comma-separated representation of a complex multiarray.

	seq -- array-like object with a ``shape`` attribute.
	n   -- current nesting depth; continuation rows are indented by n+1 spaces.

	Rank <= 1 arrays are rendered on one line with formatC; higher ranks
	recurse over the first axis, placing one sub-array per line.
	"""
	if len(seq.shape) <= 1:
		# NOTE(review): closes with "]" rather than the " ]" used by reprI/reprF;
		# kept as-is to preserve existing output -- confirm whether intended.
		return "[ " + ', '.join([formatC(x) for x in seq]) + "]"
	rows = [reprC(x, n + 1) for x in seq]
	# Indent continuation rows so they line up under the opening bracket.
	return "[" + (',\n' + ' ' * (n + 1)).join(rows) + "]"

def strC(seq, n=0):
	"""Return a human-readable string for a complex multiarray.

	seq -- array-like object with a ``shape`` attribute.
	n   -- current nesting depth; continuation rows are indented by n+1 spaces.

	Rank <= 1: elements formatted with formatC, space-separated, no brackets.
	Higher ranks: one sub-array per line, wrapped in "[" ... " ]".
	"""
	if len(seq.shape) <= 1:
		return ' '.join([formatC(x) for x in seq])
	rows = [strC(x, n + 1) for x in seq]
	return "[" + ('\n' + ' ' * (n + 1)).join(rows) + " ]"

#------------------------------------------------------------------
import JNumeric

def compactAxes(A):
	"""Return A indexed so that every axis of length 1 is dropped.

	Each length-1 axis is indexed with 0, which removes it. Every other
	axis is taken whole with a JNumeric.slice(None) index.
	"""
	selectors = []
	for extent in JNumeric.shape(A):
		if extent != 1:
			selectors.append(JNumeric.slice(None))
		else:
			selectors.append(0)
	return A[tuple(selectors)]

def DelAxes(m):
	"""Return m reshaped with every axis of length one removed.

	Only length-1 axes are dropped, so the total element count is unchanged
	and the reshape is valid. Axes of length 0 are kept: dropping them would
	change the element count. For example, shape (0, 3) would become (3,),
	and the reshape would fail.
	"""
	new_shape = [length for length in m.shape if length != 1]
	return JNumeric.reshape(m, new_shape)

